{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "from torch import nn\n",
    "from torch import optim\n",
    "import numpy as np"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "from gensim.models import Word2Vec"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "import networkx as nx"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<img src=\"./3.png\" width=600 height=500 >"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "graph = nx.Graph()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "graph.add_nodes_from([ str(i) for i in range(1, 10)])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "graph.add_node('A')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "graph.add_node('B')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "NodeView(('1', '2', '3', '4', '5', '6', '7', '8', '9', 'A', 'B'))"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "graph.nodes"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "edges = [\n",
    "    (1, 2), (1, 4), (1, 5), (2, 3), (3, 4), (4, 5), (3, 6), (6, 7), (7, 8), (8, 9), (7, 'A'), ('A', 'B')\n",
    "]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "# The graph's nodes were added as strings, so cast both endpoints of every edge to str.\n",
    "edges = [(str(src), str(dst)) for src, dst in edges]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[('1', '2'),\n",
       " ('1', '4'),\n",
       " ('1', '5'),\n",
       " ('2', '3'),\n",
       " ('3', '4'),\n",
       " ('4', '5'),\n",
       " ('3', '6'),\n",
       " ('6', '7'),\n",
       " ('7', '8'),\n",
       " ('8', '9'),\n",
       " ('7', 'A'),\n",
       " ('A', 'B')]"
      ]
     },
     "execution_count": 11,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "edges"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "graph.add_edges_from(edges)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<networkx.classes.graph.Graph at 0x2079f6c6748>"
      ]
     },
     "execution_count": 13,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "graph"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAb4AAAEuCAYAAADx63eqAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjMuNCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8QVMy6AAAACXBIWXMAAAsTAAALEwEAmpwYAAAxqUlEQVR4nO3deXjU5bn/8fcwSUgGtUiRNQYEXMDIZgRZgiiIYrVSoChagYC77NqfolZM1eM5YimGioIVEBUOcqqgYEECgiCCIKCAYBsQAmFPUJZkSCb5/v54BMGSmUlmMt9ZPq/ryhUbvvPkDtW5cz/b7bAsy0JERCRGVLM7ABERkVBS4hMRkZiixCciIjFFiU9ERGKKEp+IiMQUJT4REYkpSnwiIhJT4uwOIKJ5PLBzJ7jdkJgIjRtDnP5KRUTCmd6lKyo/H6ZOhWnTYPt2iI8HpxNKS6G4GJo1g4wMGDIEatWyO1oREfkFh25u8VNxMWRmwvjx4HBAUVH5zyYlgWXB6NEwdiwkJIQuThER8UqJzx+5udC9O+TlQWGh/69zuaBhQ8jOhpSUqotPRET8psTnS24upKVBQYGZzqwop9NMea5bp+QnIhIGlPi8KS6G1FTYsaNySe8UpxOaNIEtW8yaoIiI2EbHGbzJzDTTm4EkPTCvz8sz44mIiK1U8ZUnPx+Sk81RhXNoDBwAnEA80BF4HbjY25iJiSYBareniIhtVPGVZ+pUs3vTi4+A48A+oC4wzNeYDge8+WZQwhMRkcpR4ivPtGnejyycIRHoC3zr68GiIpg+PbC4REQkIEp85+LxmMPpfioEZgPX+vNwTo4ZX0REbKGbW85l506z+7K42OtjvTB/gSeAi4BF/owdH2/Gb9YssBhFRKRSVPGdi9ttjiD4MBf4AXADfwOuA/b7epHTWe6GGRERqXpKfOeSmFihIwxOoPdPn1f6eri01IwvIiK20FTnuTRuDCUlfj9uAR8CR4Dmvh4uKTHji4iILZT4ziUuDpo2ha1bvT52G6bKcwCNgLeAK32N3ayZWheJiNhIU53lycgwXRbKsRMowpzjOwZsBu72NWZSkhlXRERso5tbylNQYDorBHMjim5uERGxnSq+8tSqZfrpuVzBGc/lgkcfVdITEbGZKj5v1J1BRCTqqOLzJiHBNJGtVcuvc33ndKofX3a2kp6ISBhQ4vMlJcU0kW3SpOLTni6XeZ2a0IqIhA0lPn+kpMDmzTBypNmg4mW3J2ASXmIijBplpjeV9EREwobW+CqqoMC0Fpo+3Vw4HR9vpjNLS7FKStjm8XDxM89w3rBh2sgiIhKGlPgC4fGYC6fdblPhNW7Mrb16MWDAAPr162d3dCIicg6a6gxEXJy5iSU19fSNLN27d2fJkiV2RyYiIuVQ4guybt26kZ2dbXcYIiJSDiW+IEtNTeX48eN8//33dociIiLnoMQXZA6HQ9OdIiJhTImvCmi6U0QkfGlXZxXIzc0lLS2N/fv3U62afrcQEQkneleuAikpKdSsWZNNmzbZHYqIiPyCEl8V0XSniEh4UuKrItrgIiISnrTGV0UKCgpo3Lgxhw8fJiEhwe5wRETkJ6r4qkitWrW47LLLWLNmjd2hiIjIGZT4qlD37t21ziciEmaU+KqQNriIiIQfrfFVoaKiIi666CL27t3LBRdcYHc4IiKCKr4qlZSURLt27fjss8/sDkVERH6ixFfFtM4nIhJelPiqmM7ziYiEF63xVbHS0lJq167N1q1bqVevnt3hiIjEPFV8VczpdNK1a1eWLl1qdygiIoISX0honU9EJHwo8YXAqfN8mlUWEbGfEl8IXH755ZSVlZGTk2N3KCIiMU+JLwQcDoducRERCRNKfCGiYw0iIuFBxxlCZO/evaSmpnLo0CGcTqfd4YiIxCxVfCHSoEED6tWrx4YNG+wORUQkpinxhZCmO0VE7KfEF0La4CIiYj+t8YXQjz/+SHJyMocO
HSIxMdHucEREYpIqvhD61a9+RWpqKqtWrbI7FBGRmKXEF2Ka7hQRsZcSX4hpg4uIiL20xhdiJ0+epHbt2uzevZuaNWvaHY6ISMxRxRdi1atXp2PHjixbtszuUEREYpISnw3UpkhExD5KfDbQBhcREfso8dmgdevWHD58mD179tgdiohIzFHis0G1atW44YYbtLtTRMQGSnw26datmxKfiIgNdJzBJtu3byc9PZ28vDwcDofd4YiIxAxVfDZp0qQJ1atXZ+vWrXaHIiISU5T4bOJwODTdKSJiAyU+G+k8n4hI6GmNz0YHDx7ksssu4/Dhw8TFxdkdjohITFDFZ6M6derQqFEj1q1bZ3coIiIxQ4nPZqenOz0eyMmBzZvNZ4/H7tBERKKSpjrtlJ/P1j/+kaTZs2ns8UB8PDidUFoKxcXQrBlkZMCQIVCrlt3RiohEBSU+OxQXQ2YmjB+P5XDgKCoq/9mkJLAsGD0axo6FhITQxSkiEoWU+EItNxe6d4e8PCgs9P91Lhc0bAjZ2ZCSUnXxiYhEOSW+UMrNhbQ0KCgw05kV5XSaKc9165T8REQqSYkvVIqLITUVduyoXNI7xemEJk1gyxazJigiIhWiXZ2hkplppjcDSXpgXp+XZ8YTEZEKU8UXCvn5kJwMbrfXx7oCXwP7geq+xkxMNAlQuz1FRCpEFV8oTJ0KPjow7ARWAA7gQ3/GdDjgzTcDDk1EJNao4guFFi3ARxeGPwOLgPbAv4D5/o67ZUvA4YmIxBIlvqrm8UCNGmZzixfNgNGYxHctsAeo62vshAQ4cQJ0z6eIiN801VnVdu70uftyJbAL6AdcDTQFZvozdny8GV9ERPymxFfV3G5zBMGLt4AeQO2f/vddP33NJ6fT54YZERE5m+bIqlpiotcjDEXAe0ApUO+nr50EfsDs8GzlbezSUjO+iIj4TRVfVWvcGEpKyv3juYAT+BbY+NPHViAdmOFr7JISM76IiPhNia+qxcVB06bl/vFbQAaQgqn4Tn0MBd4FvDYnatZMG1tERCpIiS8UMjJMl4VzWAj85Rxf74c5yF5uWktKMuOKiEiF6DhDKBQUmM4KwdyIoptbREQqRRVfKNSqZfrpuVzBGc/lgkcfVdITEakEVXyhou4MIiJhQRVfqCQkmCaytWr5PNdXrlP9+LKzlfRERCpJiS+UUlJME9kmTSo+7elymdepCa2ISECU+EItJQU2b4aRI80GlXJ2e55SEh9vnhs1ykxvKumJiAREa3x2KigwrYWmT4ecHDN96XSaNcCSEo7Xr8+UkhJGb9qkjSwiIkGixBcuPB5z4bTbbSq8xo0pq1aNSy65hHnz5tG6dWu7IxQRiQpKfGHumWee4ejRo0yYMMHuUEREooISX5jbvn07HTp0YM+ePSQkJNgdjohIxNPmljDXtGlTmjdvzscff2x3KCIiUUGJLwIMGjSI6dOn2x2GiEhU0FRnBDh27BgpKSl899131KlTx+5wREQimiq+CHD++edz++238+6779odiohIxFPiixCDBg1i2rRpqEAXEQmMEl+E6NKlC8eOHWPjxo12hyIiEtGU+CJEtWrVGDhwINOmTbM7FBGRiKbNLRHk+++/p127duzZs4fq1avbHY6ISERSxRdBLrnkElJTU1mwYIHdoYiIRCwlvgijM30iIoHRVGeEOX78OBdffDFbt26lXr16docjIhJxVPFFmPPOO4/f/e53OtMnIlJJSnwR6NR0p4p1EZGKU+KLQJ07d6awsJD169fbHYqISMRR4otAOtMnIlJ52twSoXbt2sXVV19NXl6ezvSJiFSAKr4I1ahRI1q1asVHH31kdygiIhFFiS+C6UyfiEjFaaozgp04cYLk5GS+/fZb6tevb3c4IiIRQRVfBKtRowa9e/fmnXfesTsUEZGIocQX4TIyMnSmT0SkApT4IlynTp04efIk69atszsUEZGIoMQX4RwOx+nu7CIi4ps2t0SB3Nxc2rRpQ15e
HomJiXaHIyIS1lTxRYGUlBTatm3Lhx9+aHcoIiJhT4kvSuhMn4iIfzTVGSUKCwtJTk5m06ZNNGzY0O5wRETCliq+KOFyuejTp4/O9ImI+KDEF0V0pk9ExDclvijSoUMHSktL+fLLL+0ORUQkbCnxRRGd6RMR8U2bW6LM7t27ad26NXv27CEpKcnucEREwo4qvihz8cUXk5aWxrx58+wORUQkLCnxRSGd6RMRKZ+mOqNQUVERDRs25JtvviE5OdnucEREwooqviiUlJTE73//e95++227QxERCTuq+KLU6tWrGThwINu2bcPhcNgdjohI2FDFF6Xat2+Pw+Fg9erVdociIhJWlPiilM70iYicm6Y6o1heXh5XXXUVe/bsweVy2R2OiEhYUMUXxRo2bEi7du2YO3eu3aGIiIQNJb4od+riahERMTTVGeXcbjcNGzZkw4YNpKSk2B2OiIjtVPFFucTERPr166czfSIiP1HFFwPWrFnDH/7wB/71r3/pTJ+IxDxVfDGgXbt2xMfHs2rVKrtDERGxnRJfDNCZPhGRn2mqM0bs3buXK6+8kj179lCjRg27wxERsY0qvhjRoEEDOnTowAcffGB3KCIitlLiiyE60ycioqnOmOJ2u0lOTuarr76iUaNGdocjImILVXwxJDExkTvuuIMZM2bYHYqIiG1U8cWYtWvXcuedd5KTk6MzfSISk1TxxZi0tDSSkpJYuXKl3aGIiNhCiS/G6EyfiMQ6TXXGoH379tGiRQt2797NeeedZ3c4IiIhpYovBtWvX59OnTrx/vvv2x2KiEjIKfHFqLPO9Hk8kJMDmzebzx6PrbGJiFQlTXXGqJN79/LipZfyZIMGJOTmQnw8OJ1QWgrFxdCsGWRkwJAhUKuW3eGKiASNEl+sKS6GzEwYP56THg/VvVV3SUlgWTB6NIwdCwkJoYtTRKSKKPHFktxc6N4d8vKgsND/17lc0LAhZGeDuriLSIRT4osVubmQlgYFBWY6s6KcTjPluW6dkp+IRDQlvlhQXAypqbBjR+WS3ilOJzRpAlu2mDVBEZEIpF2dsSAz00xvBpL0wLw+L8+MJyISoVTxRbv8fEhOBre73Ed+eYS9CHgYmFjeCxITTQLUbk8RiUCq+KLd1Kng4zLq42d87AeSgN97e4HDAW++GawIRURCSokv2k2bBkVFfj/+D6AOkO7toaIiUENbEYlQmuqMZh4P1KhhNrf46QagC/CsrwcTEuDECYiLq3x8IiI2UMUXzXburNDuy13AcmCgPw/Hx5vxRUQijBJfNHO7zREEP70NdAYu8edhp9PrhhkRkXClxBfNEhMrdIRhBn5We4BVWmrGFxGJMFrji2YVWONbBdyI2dV5vh9DnwT63XILXW64ga5du9K6dWucFaguRUTsooovmsXFQdOmfj36FtAb/5IeQLXLLuPugQPZvn0799xzD7Vr1+a3v/0t48ePZ/369ZQGelheRKSKqOKLduPGmc4KFTjS4FNSEvz5z/DYY6e/dODAAZYvX86yZctYtmwZ+/btIz09na5du9K1a1datWqlilDEF4/HbBpzu81SQuPG2jldBZT4ol1BgemsEMyNKH7c3LJ///6zEuH+/fvp0qXL6UTYsmVLJUIRMLcrTZ1qztxu367emCGgxBcLnnoKJkyoWCui8rhcMGoUPP98hV72y0R44MCBsypCJUKJOWf0xsTh8D4ro96YQaXEFwuKiyls0oSEvDwCmjQJYneGffv2nZUIDx48+B8VYbVqWoKWKKXemLZS4osBH3/8MU/dcw9ryspIOHYsLPvxnUqEn376KcuWLePQoUOnE+H111/PVVddpUQo0UG9MW2nxBflPvjgAx544AE+/PBDrm3QIGJ+y9y7d+9ZFeHhw4fPqgjDOhFqg4KUR70xw4ISXxSbPXs2I0aM4OOPP6Zt27bmixVZV3C5oKwMHn3UrCvY+B/YqUR4qiLMz88/qyJMTU21NxFqg4L4IwzW20WJ
L2q99dZbjBkzhkWLFnHVVVf95wMFBaa10PTpkJNz9ht1ScnPb9SDB4flG3VeXt5ZFWF+fj7XXXfd6YowZIlQGxTEXz56Y57E9MHMBgqApsCLQE9vY6o3ZqUo8UWhKVOm8Nxzz7F48WKuuOIK3y+Igqm5U4nwVEV45MgRunTpwvXXX0/Xrl258sorg58ItUFBKsLHmdoTwDhgEJACfAz0BzYBjcsbMynJ/OL1xz8GO9qopsQXZbKyshg/fjxLliyhqZ+3tkSjPXv2nFURHjly5KyKMOBEqA0KUlEtWsDWrRV6SUtgLNDH17hbtgQQWOxR4osiL730EpMnT2bp0qU0atTI7nDCyu7du89KhD/++ONZibBFixb+J0JtUJCKqkRvzANAI2Aj4HXeRr0xK0yJLwpYlsVzzz3HzJkzWbJkCQ0bNrQ7pLAXUCLUBgWpqJwcaN3aJCg/lGDW9poCk309XKMGbNxo1uXFL0p8Ec6yLJ588knmz59PdnY2devWtTukiJSbm3tWIjx27Nh/JEKHw+Fzg8Ip/wtkArlAPWA6kF7ew9qgEP02b4ZOneDoUZ+PlgF3AUeBeYDPuYALLoDPPzezEOIXJb4IZlkWo0aN4rPPPuOTTz6hdu3adocUNcpLhCNLSuiwaBHVvCS+xcC9wGygHbDvp6+XW4drg0L087Pis4DBwE7M5pYkf8ZWxVdhSnwRqqysjEceeYQNGzbwz3/+kwsvvNDukKLarl27WL58Od2GD6fhjz96fbYjMOSnD79pg0J083ON70HMml42cJ6/Y2uNr8KU+CJQaWkp9957Lzk5OSxYsIALLrjA7pBigx9vXqWY39L/DPwdcAO9MNvUvf72rjevqGe1aIHDy67OXZhjC9XhrDt1JwN3extYvzRVWJje+STlKSkp4Z577iE3N5eFCxcq6YXSzp0+d18ewGxM+D9gBea39w2Az60r8fFmfIk6hYWFvPHGG4wvKKDI4Sj3uUaYqU43cPyMD69JLynJXDQhFaLEF0GKi4u58847+eGHH5g/fz41atSwO6TY4nabIwhenKrqhgH1gdrAaMx6jVdOZ3B7Jortdu3axeOPP06jRo346KOPuHrSJBKrVw/uN7Esc7uSVIgSX4Rwu9307t2b0tJSPvjgA5KS/Fr2lmBKTPR5bu9CIBk48/f68n/H/5nl8ZjxJaJZlsXy5cvp06cPbdu2paSkhNWrV/Phhx/StXdvHKNHmyMswfheLpe5R1e7gStMa3wRoLCwkF69elGrVi3efvtt4nXY2R5+blB4BvgnsACzFf23QFfgOS+vOQl0aduWTtddR3p6Op07d+aiiy4KTtxS5YqKipg5cyZZWVmcPHmS4cOHM2DAAM477xdbVIJ0+UGpw8He6tWpnpNDHZ3brTBVfGHu2LFj9OzZk/r16/Puu+8q6dkpLg78uAbuT8A1wGVAc6AN8JSP18Q3b87LEyZQu3ZtpkyZQrNmzWjevDn3338/77zzDrt27Qo4fAm+3bt3M2bMGFJSUvjggw8YN24c3377LQ8//PB/Jj0wm5iys02V5mPavFxOJ9Vq12b2fffRrlMn1q9fH9gPEYssCVtHjhyxrr32Wuu+++6zSktL7Q5HLMuyXnrJspKSLMusrgTnIynJssaNO+vbeDwea/369daECROsPn36WHXq1LEuvvhi66677rJee+01a8uWLVZZWZlNfwmxrayszPrss8+svn37WhdeeKE1YsQI61//+lfFBtm1y7IuvdSyXK6K/bvicpnX7dplWZZlzZkzx6pdu7Y1a9asKvhJo5emOsNUfn4+N910Ex07duSVV14xt4aI/QoKTGeFYG5E8ePmFsuy+Pe//82KFSv47LPPWLFiBUePHqVz586kp6fTpUsX2rRpQ5yOQ1QZt9vNrFmzyMrKorCwkGHDhjFw4EDOP//8yg0YpN6Y33zzDb169aJfv3688MILOCtbScYQJb4wdPDgQbp3707Pnj357//+byW9
cBMmd3Xm5eWxYsWK0x87d+6kffv2pKenk56eTvv27XEFaSNFLNuzZw+vvfYab7zxBmlpaQwfPpwePXoEr81VEHpjHj58mH79+pGYmMjMmTOpWbNmcGKLUkp8YWbv3r1069aNfv368eyzzyrphaMw7c5QUFDA559/fjoRfvPNN7Rq1ep0IuzUqZNu+PGTZVmsWrWKiRMn8sknn/CHP/yBRx55hMsvv7xqv3EAvTFLSkp47LHHWLhwIXPnzqV58+ZVGmokU+ILI7m5uXTr1o3BgwczZswYu8MRbyKgH9+JEydYs2bN6US4Zs0aLrnkErp06XI6GTZo0KBKvnfQhahZstvtZvbs2WRlZXH06FGGDRvGoEGDIuqiiGnTpvH4448zdepUbr31VrvDCUtKfGFix44ddOvWjREjRjBy5Ei7wxF/RFgH9pKSEtavX386Ea5cuZKaNWuelQibNWsWPrMM+fkwdSpMmwbbt589BVhc/PMU4JAhAZ9l27t3L6+99hpTpkyhTZs2DB8+nJtvvjl405khtnr1avr27cvDDz/MmDFjwuf/03Bhz54aOdO2bdus5ORka9KkSXaHIhV18qRlPfmkZSUm+t7t6XKZ5556yrKKi+2O3CotLbU2b95sTZo0yerfv7+VnJxs1atXz+rbt6+VlZVlbdiwwfJ4PKEPrCJ/p0lJ5rknnzSvq4CysjLriy++sPr3729deOGF1iOPPGJt3bq1in6o0MvLy7Pat29v/f73v7eOHz9udzhhRRWfzTZv3sxNN93Ec889x2BdPRS5grBBwW6WZbFr166zdo7u37+fjh07nt45mpaWRvVgX7t1phBU0SdPnmTOnDlkZWWRn59/ejozGjeEuN1uHnzwQTZs2MDcuXO55JJL7A4pLCjx2WjDhg307NmT8ePHc9ddd9kdjgRLiNajQuHgwYOsXLnydDL87rvvuPrqq09PjXbs2LHy2/l/qYrXTfft28fkyZOZPHkyV111FcOHD6dnz55Rv/3fsiyysrJ48cUXmTVrFtdff73dIdlOic8mX375JbfddhuTJk2iT58+docj4pejR4/yxRdfnF4n/Oqrr7jiiitOJ8LOnTtTp06dig9chTtlv/zyS7KysliwYAH9+/dn6NChtGjRovLfI0ItWbKEu+++m6eeeoqhQ4fG9LqfEp8NVq5cSe/evbXrSiLeyZMnWbt27elEuGrVKurXr386Eaanp9OoUSPfb7JBPhvpGT6c9666iqysLA4cOMCwYcPIyMiI+eMcO3bsoFevXlxzzTVMmjSpaqetw5gSX4gtXbqUO+64g3fffZcePXrYHY5IUJWWlrJp06bTa4QrVqwgPj7+rJ2jzZs3P3u3ZH4+JCd7vQ3nD8AS4ARQD/h/wL1e4nADd6ank/Hoo9x6661RP51ZEcePH2fQoEHk5eXx/vvvU79+fbtDCjklvhBauHAhAwYMYM6cOVx33XV2hyNS5SzLIicn56xE+OOPP9KpU6fTyfDqpUtx/vnPXq/s2gI0w3Qn34bpdrEAuLqc58uqV6fac8/BH/8Y5J8oOpSVlfHCCy8wefJk3n//fdq1a2d3SCGlxBci8+bN47777mPu3Ll07NjR7nBEbJOXl8fKlStPJ8PZmzfTvAJvQ99hEt8rQD9vD7ZoYdb6pFzz5s3j3nvv5eWXX2bgwIF2hxMySnwhMGfOHIYOHcqCBQtIS0uzOxyR8OHxYLlcOEpKfD76MDAdKMK0evoMOEfjn58lJMCJExG7ozZUtmzZQq9evbj11lsZN25cTFx0HpnXEkSQd955h+HDh/PJJ58o6Yn80s6dOBIS/Hp0EnAMWAH0xkx7ehUfb46ViFdXXnklX375Jd9++y0333wz+fn5dodU5ZT4qtDf//53nnjiCZYsWUKrVq3sDkck/LjdFWrI6gQ6A3uA13w+7Axu+6goduGFF/Lxxx/Ttm1b2rVrx6ZNm+wOqUpFf01rk1dffZWXXnqJTz/9lEsvvdTu
cETCU2Jipc7teYDtvh4qLTXji1+cTicvvfQSrVq14oYbbmDy5Mn07t3bvxdH2KUNWuOrAi+//DKTJk1i6dKlNG7c2O5wRMKXxwM1apgD7OU4CCwFbgWSgGzMVOcs4LfextYaX6V99dVX/O53v2PQoEE8++yz576sO4SXiAebEl+QPf/888yYMYOlS5eSnJxsdzgi4a9FC9i6tdw/PgT0Bb4GyoBGwHDgPn/G1a7OSjtw4AB9+vTh17/+NW+//fbPrZkq0jk+KclcJz56tOkc7+d6blXTGl95PB5z2fDmzeazx+P1ccuyePrpp5k1axbLly9X0hPxV0aGeYMsx0XAcuAH4CiwCT+SXlKSGVcqrW7duixdupR69erRoUMHcnJyzH2qqanmlh2323vSA/Pnbrd5PjXVvD4MqOI7UyVLd8uyeOyxx1iyZAmLFy/moosusvGHEIkwBQWms0IwN6IkJpoOD2E2xRapXn/9daY8/TRfeDxUP348bJsv+0uJDwIq3cvi4hg2bBhr165l4cKF1NJ/aCIVF8S7OgsdDhZecQXpy5frl9BgKS6msGlTEvbsCWxH5DkuEbeDpjoDKN2tK6/kibvuYuPGjSxevFhJT6Syxo41VV+gd2o6nSQ2bcrqm26iZcuWzJkzJzjxxbrMTFwFBYEfAygtNZV4ZmYwoqq02K74Auz/VepwcDQujoSvv6ZG8+ZVEKBIDAlyP77Vq1eTkZFBamoqr776auXaJYlfl4j/DXOrziag/0//7JXNU9GxW/EVF5tOz5X9jwxwWhY1y8qocfvtpsu2iFReSopJWk2amI7qFeFymdedsX507bXXsmHDBpo1a0bLli2ZNWsWsfx7fqVNnWqWgLxoADwNDPZ3TIcD3nwzwMAqL3YTX2am+Y0jkKaXgCNMSneRqJCSYnZSjxxpqgIvuz0Bk/ASE2HUKLNu9ItNE4mJibz44ot89NFHvPDCC/Tu3Zv9+/dXXfzRaNo0n0tAvYFewK/9HbOoCKZPDyisQMRm4svPNxtZvCyk7wRuAS7E9P8airkt4pwKC+EvfzHVo4gEJiEBXnjh518oW7QwX6tRAy64wHxOSDBfP/UL7PPPe90scc011/DVV1+RmppKq1atePvtt1X9+cPjMTvcq4Ifx8SqSmyu8Y0bZxbTvfwWcwtQB3gdc37oRszZoeHlvSApyfxHqP5fIsEXxCux1q9fT0ZGBikpKbz++us0bNgwqKFGlZwcaN3a3IDjh6cx96hO9+fhGjVg40ZzTCzEYrPi86N0/x7T6ysRU/HdjGmGWS6bS3eRqBYXZ94gU1PN5wCuIWvbti1r164lLS2NNm3aMG3aNFV/5angJeIVYuMl4rGX+Pws3UcC/wsUAnnAPzHJzysbS3cR8V9CQgJjx45l8eLFTJw4kZ49e7J79267wwo/lbxE3C82XiIee4lv506/Dk52wVR4FwDJQBpm8dYr9f8SiSitWrVizZo1dO7cmbZt2/LGG2+o+jtT48Z+7Vj3AG6g9KcPN172RJxSUmLGt0HsJT4/SvcyTHXXGzgBHAaOAI/7Glv9v0QiTnx8PE8//TSffvopU6ZMoUePHuzUL7BGXBw0berzsecxnTP+G3jnp39+3teLApyyDkTsJT4/SvcCIBezk7M6ZotuBvCxr7HV/0skYqWmpvLFF1/QrVs30tLSeO211ygrK7M7LPv5uEQc4FnA+sXHs95eYPMl4rG3q9OP/l8ATYD7gceA45jElwTM9PYi9f8SiQpbt24lIyODpKQk3nzzTZo0aWJ3SPaJwkvEY6/i87N0fx9YiGmJ0gyIB/7q60U2lu4iEjzNmzfn888/5ze/+Q3t2rVj4sSJsVv91arF/rvuotDH7S1+c7ng0Udt7ZwRe4kP/CrdWwPLMGt7h4H3gLreXqD+XyJRxel08thjj7Fq1Spmz55N165d+fe//213WCH3zjvv0HbePErr1g3KJeI0bGjOUdsoNhPfkCGmtVAwWRYM
9vumOhGJEJdddhnLly+nT58+dOjQgb/+9a+UVtUW/zBSUlLCiBEjyMzM5JNlyzh/zRpTpVU2+Z26RDw729aWRBCria9WLdNPr6IX4ZYnDEp3Eak6TqeTESNGsHr1aubOnUt6ejrfffed3WFVmf3799OtWze2b9/O2rVrSU1NDfol4naKzcQHp/t/lVUL8K8gTEp3Eal6zZo149NPP+Xuu++mc+fOvPTSS3ii7NKKL774gmuuuYYbbriBDz/8kJo1a/78hxW8RNzycYm4XWJvV+cZFr3xBlc/8AC1HA6qVWbh+hf9v0Qkdnz//ffce++9HDt2jGnTpnHllVfaHVJALMtiypQp/OlPf2Lq1Knceuut3l9QUGBaC02fbm6tio8374mlpVglJfy7rIwLRo6k3pgxYTcbFrMV35IlS7jnqafY+9FHVGvaNOJLdxEJrUsuuYTs7GyGDBlC165d+a//+q+Irf7cbjf33nsvEydO5PPPP/ed9MAksz/+0VRyJ06YC6c//xw2bsRx4gTjhwxhZv36YZf0IEYT35o1a+jfvz//93//R8vf/Cao/b9EJHY4HA4eeOAB1q1bx7Jly2jfvj3ffPON3WFVSG5uLunp6Rw/fpzVq1dz6aWXVnyQc1wifvPNN7Nw4cLgBxwMVozZtGmTVbduXWv+/Pn/+Yf5+Zb10kuW1aKFZSUkWFaNGpZ1wQXmc0KC+fq4ceY5EZEzlJWVWX//+9+t2rVrW5mZmVZxcbHdIfm0dOlSq169eta4ceOssrKyoI599OhR67zzzrOOHz8e1HGDIabW+Hbs2EGXLl0YN24c/fv39/5wEPt/iUjs2LNnD/fffz979+5l+vTptG7d2u6Q/oNlWYwfP55x48bx7rvv0q1btyr5Ptdffz2PPfYYv/nNb6pk/MqKmanOvXv3cuONN/L000/7TnoQ1P5fIhI7kpOTWbBgAaNGjaJHjx4888wzFPu4IjGUTpw4Qf/+/Zk1axZr1qypsqQHhO10Z0wkvvz8fHr06MG9997Lgw8+aHc4IhLlHA4HAwcOZOPGjWzcuJG0tDS++uoru8MiJyeHa6+9lqSkJFasWEGjRo2q9Psp8dnk2LFj9OzZk1tuuYUnnnjC7nBEJIY0aNCAefPm8fjjj3PLLbfw5JNPcvLkycoP6PGYowObN1e48fWCBQvo2LEjDz/8MFOnTiXJ10a+IGjZsiUnTpwgJyenyr9XRUR14nO73dx+++20bt2a//mf/8ERrEtWRUT85HA4uPvuu/n666/Ztm0bbdq0Yc2aNf4PkJ8P48ZBixams0zr1tCpk/nscpmvjxtnztWdQ1lZGZmZmTzwwAPMnTuXhx56KGTvhQ6HIyyrvqjd3OLxeOjbty/Vq1dn5syZOAO9XFVEJECWZfHee+8xYsQIBgwYQGZmZvmVV3ExZGbC+PHgcEBRUfkDJyWZ+4JHjza3SCUkAPDDDz9wzz33cOTIEebMmUP9+vWr4Kfy7r333mPGjBnMnz8/5N+7PFGZ+MrKyhg0aBCHDh1i3rx5JPz0L4GISDg4ePAgQ4cO5euvv2batGl07Njx7Adyc6F7d9OzrrDQ/4FdLnOFYnY2m48epXfv3tx000385S9/se198MiRIzRq1IiDBw+SGCaNuqMu8VmWxYgRI1i/fj2ffPIJrmBdRC0iEmT/+Mc/GDp0KP379+f5558371e5uZCWZqYuK9MFwunEXaMG1zqdjJ4wgQEDBgQ/8Arq1KkTzz77LDfeeKPdoQBRuMb37LPPsmLFCubPn6+kJyJhrU+fPmzatIkDBw7QqlUrVixZYiq9yiY9gNJS4o4eZfX55zPAn6NbIRBu63xRVfFNmDCBSZMmsWLFCurW9do2VkQkrMybN4+dd9/NgydPUj0Yd366XOZqxeefD3ysAK1du5ZBgwaxZcsWu0MBoijxTZs27XS1l6L7M0Uk0uTnYyUn43C7vT72b+AqoC/wjq8xExPNOqHNF0WXlZVRr149
1q1bFxbvz1Ex1fn+++/z5JNPsmjRorD4SxURqbCpU/06ZvAIcI2/YzocpnWQzapVq0aPHj3CZroz4hNfdnY2Dz74IAsWLOCKK66wOxwRkcqZNs37kQXgf4GagN+XjBUVmX55YSCc1vkieqpz9erV3Hbbbbz//vukp6fbHY6ISOV4POZwupc7PY8CacBS4O9ADn5MdYI503fihO33DR86dIhLL72UQ4cOER8fb2ssEVvxbdq0idtvv5233npLSU9EItvOnaaDuRd/AoYAyRUdOz7ejG+ziy66iEsvvZRVq1bZHUpkJr6cnBxuvvlmsrKyuOWWW+wOR0QkMG43eLldaiOQDYyqzNhOpxk/DITLdGfEJb68vDx69OjB2LFjueOOO+wOR0QkcImJXs/tLQN2AilAPeBl4B9AW3/GLi0144eBcEl8EbXGl5+fT5cuXRgwYACPP/643eGIiASHjzW+Qswa3ykvYxLha8BFvsYOkzU+MHco16lThy1btthyb+gpEVPxnWovdNtttynpiUh0iYuDpk3L/WMXptI79XEekIgfSQ/CqpF2XFwc3bt3Z9GiRbbGERGJ71R7obZt2/Liiy/aHY6ISPBlZJguC354Fv92dJbEx1N4552BRBV04TDdGfZTnSUlJfTt25ekpCTeffddtRcSkehUUGA6KwRxI0qx08nlNWpw8113MWzYMFq0aBG0sSsrLy+Pli1bcvDgQdvez8O64isrK2Pw4MGUlJQwY8YMJT0RiV61apl+esG6XN/lIuGJJ1i1bRt169blhhtuoEePHixYsICysrLgfI9KaNiwIcnJyXz55Ze2xRC2FZ9lWQwfPpyvv/6ahQsXqtOCiES/4mJITYUdOyrfnQHMEYYmTWDLltPnA0+ePMns2bN55ZVXOHr0KMOGDWPQoEFccMEFQQref48//jiJiYlkZmaG/HtDGFd8Y8eO5fPPP+ejjz5S0hOR2JCQANnZpvqr7AyX02len5191qH46tWrM2DAANatW8f06dNZuXIljRs3ZuTIkeTk5ATpB/CP3et8YZn4/vrXvzJ79mwWLlzIr371K7vDEREJnZQUWLfOVGwV/aXf5TKvW7fOjHMODoeDTp068d577/H111+TlJREhw4duO2228jOziYUk4CdOnVi27ZtHD58uMq/17mEJvF5PJCTA5s3m89eek1NnTqVCRMmsHjxYurUqROS8EREwkpKinm/HDnSHD73tdvT5TLPjRplpjf97FJz8cUX8+KLL7Jr1y5++9vfMmrUKFJTU5k8eTKFhYWB/xzlSEhI4Prrr+eTTz6psu/hTdWt8eXnw9Sp5sbx7dtNye10mnnr4mJztiQjA4YMOd0r6h//+AfDhg1j2bJlXHbZZVUSlohIRCkoMK2Fpk83hcOZ76UlJT+/lw4eHHDfPcuy+PTTT8nKymLlypUMHjyYRx55hEaNGgXnZznD66+/zqpVq5gxY0bQx/Yl+ImvuBgyM2H8eNMLylubjaQksCwYPZrFHTtyd0YGixYtok2bNkENSUQkKng85sJpt9tUeI0bV9nh9B07dvC3v/2Nt956i65duzJixAjS09P96hnoj507d9K+fXv27dtHtbKykP1cEOzEl5sL3bubjr8VKJNLExPZWVxMwXvvcU2fPkELR0REAnP8+HHeeustsrKycLlcjBgxgjvvvJPEQO//zM9nXPPmDD3vPJLy8vyaFQyW4CW+3FxISzNleSW24ZZVq0a1X//a66KsiIjYo6ysjEWLFpGVlcX69eu5//77eeihh2jQoEHFBjpjVrDY4yHBy56PM2cFGTvW7HoNguAkvio8eyIiIuFl27ZtTJw4kZkzZ9KzZ09GjBhB+/btfb+wkrOCuFzmVpvs7KAURsHZ1ZmZaX6QQJIemNfn5ZnxREQkLF1xxRW8+uqrfP/996SlpdG/f3+uvfZaZs2aRXF5XeRPzQru2FGxpAfm+R07zOtzcwOOP/CKLz8fkpO93i+3FXgE
+Apzm/g44HfexkxMNAkwyPO6IiISfKWlpcyfP59XXnmF7777joceeoj777//5yNpYTYrGHjFN3Wq2b1ZDg9wO3ArUABMAf4A/MvbmA6H2b4rIiJhz+l0cvvtt7N06VIWLlzIrl27uPzyy8nIyGDjxo1hNysYeMXXogVs3VruH28GrgWOAafSYw+gPfCcr3G3bAkoNBERscfhw4d54403mDlxIusOHKC6l4uxuwKrgVMHGBoC33kbPMBZwcAqPo/HHE6vIAuTEL3yccOLiIiEr9q1azNmzBg2Dh+O048zeX8Djv/04TXpQcCzgoElvp07fc6zXg7UwazrlQCfAMsBn0ub8fFmfBERiVjOGTOIK2/DS2UVFZmbbCopsMTndvu8QTwemAssAOoBfwH6Acm+xnY6g9qQUUREQqwCs4JjgNpAJ2CZPy8IYFYwsMSXmOjXYmVLTJWXDywCdgDtfL2otNSMLyIikcmPWUGA/8HkhTzgfuA2wGe6DGBWMLDE17ixuSTVh28AN2Z682VgHzDI14tKSsz4IiISmfyYFQSz2fF8oDowEFP1fezrRQHMCgaW+OLioGlTn4+9DdTHrPUtARZjfkCvmjWr0ktKRUSkivk5K/hLDswmSK8CmBUM/BxfRobPXlHjgCOY3Tr/BJr5GjMpyYwrIiKRy49ZwR8wS2BuzLnvd4HPgJt9jR3ArGDg5/gKCswdasHciKKbW0REooOPs96HgFuAbYATuAJzxvtGf8at5FnvwCu+WrXMzdkuV8BDAWacRx9V0hMRiQY+ZgUvAtZiLjn5AXOQ3WfSC3BWUN0ZRESk6oThrGBwujMkJJh2EbVq+bWD55ycTvP67GwlPRGRaBGGs4Jh0YE92L2WREQkjITZrGBwKr5TUlJg82YYOdKUoj52e+JymedGjTI/iJKeiEj0CbNZweBWfGcqKDCXiE6fbq6WiY83gZeWmm2ozZqZxcnBg7WRRUQkFoTJrGDVJb4zeTzmahm321R4jRvrcLqISCwqLjb99MaPN10WiorKf9blgrIys6Y3dmzQ9n+EJvGJiIicycZZQSU+ERGxV4hnBZX4REQkpgR3V6eIiEiYU+ITEZGYosQnIiIxRYlPRERiihKfiIjEFCU+ERGJKf8fMt6W2/yU2n8AAAAASUVORK5CYII=\n",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "nx.draw(graph, with_labels=True, node_color=\"r\", node_size=500)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<img src=\"./4.jpg\" width=400 height=250 >"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<img src=\"./5.jpg\" width=400 height=250 >"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [],
   "source": [
    "from node2vec import Node2Vec"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [],
   "source": [
    "dimensions = 32"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "0d491fc4f1394a31ac3f7dc10f329dec",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(FloatProgress(value=0.0, description='Computing transition probabilities', max=11.0, style=Prog…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Generating walks (CPU: 1): 100%|███████████████████████████| 10/10 [00:00<00:00, 313.45it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "node2vec = Node2Vec(graph, dimensions=dimensions, walk_length=6, num_walks=10) "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [],
   "source": [
    "model = node2vec.fit(window=10, min_count=1, batch_words=4) "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [],
   "source": [
    "embedding_maxtrix = model.wv.vectors"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [],
   "source": [
    "idx2word = model.wv.index2word"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[('7', 0.5307246446609497),\n",
       " ('8', 0.1598607748746872),\n",
       " ('1', 0.12193160504102707),\n",
       " ('6', 0.12057728320360184),\n",
       " ('5', 0.0355752557516098),\n",
       " ('9', 0.019031256437301636),\n",
       " ('3', -0.07893448323011398),\n",
       " ('2', -0.08174138516187668),\n",
       " ('4', -0.12552376091480255),\n",
       " ('B', -0.13893216848373413)]"
      ]
     },
     "execution_count": 21,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "model.wv.similar_by_word(\"A\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[('8', 1.0),\n",
       " ('6', 0.1768108755350113),\n",
       " ('A', 0.1598607748746872),\n",
       " ('7', 0.10418733954429626),\n",
       " ('4', 0.08851936459541321),\n",
       " ('2', -0.01717028021812439),\n",
       " ('9', -0.02063833177089691),\n",
       " ('5', -0.029865790158510208),\n",
       " ('1', -0.04628392681479454),\n",
       " ('3', -0.10360722243785858)]"
      ]
     },
     "execution_count": 22,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "model.wv.similar_by_vector(embedding_maxtrix[5])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Vocabulary list of the trained model; position i is the label stored at row i.\n",
    "index2word = model.wv.index2word\n",
    "# Lookup tables in both directions between row index and node label.\n",
    "idx2word = dict(enumerate(index2word))\n",
    "word2idx = {label: row for row, label in enumerate(index2word)}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{0: '7',\n",
       " 1: '3',\n",
       " 2: '4',\n",
       " 3: '1',\n",
       " 4: '6',\n",
       " 5: '8',\n",
       " 6: 'A',\n",
       " 7: '2',\n",
       " 8: '5',\n",
       " 9: '9',\n",
       " 10: 'B'}"
      ]
     },
     "execution_count": 24,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "idx2word"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[('1', '2'),\n",
       " ('1', '4'),\n",
       " ('1', '5'),\n",
       " ('2', '3'),\n",
       " ('3', '4'),\n",
       " ('4', '5'),\n",
       " ('3', '6'),\n",
       " ('6', '7'),\n",
       " ('7', '8'),\n",
       " ('8', '9'),\n",
       " ('7', 'A'),\n",
       " ('A', 'B')]"
      ]
     },
     "execution_count": 25,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "edges"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "metadata": {},
   "outputs": [],
   "source": [
    "embedding_maxtrix = torch.from_numpy(embedding_maxtrix)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[ 3.4896e-04, -5.8526e-03,  5.1271e-03,  5.1128e-03,  1.3546e-02,\n",
       "          8.1610e-03,  7.4118e-03, -1.1965e-02, -2.0133e-03,  1.4264e-02,\n",
       "          1.1182e-02,  1.0829e-02,  2.6168e-03, -1.3432e-02, -1.4004e-02,\n",
       "          4.7380e-03,  3.5851e-03,  1.0873e-03,  3.0503e-03,  1.0967e-02,\n",
       "         -5.4748e-03,  1.2567e-02, -5.9085e-03, -1.0333e-02,  1.5333e-02,\n",
       "         -1.4712e-02,  4.7506e-04,  1.0400e-02, -1.3373e-02,  1.4804e-02,\n",
       "         -9.9758e-03,  7.7217e-03],\n",
       "        [ 3.4988e-03, -1.0074e-02, -6.8107e-03,  1.5451e-02, -1.3742e-02,\n",
       "         -6.6256e-03,  3.3363e-03, -7.6711e-03,  1.1953e-02,  8.2647e-03,\n",
       "          8.6710e-04, -6.4642e-03,  9.0327e-03, -2.5512e-03, -5.2309e-03,\n",
       "         -1.1418e-03, -7.9369e-03,  8.5218e-03, -2.6970e-04,  1.3964e-02,\n",
       "          7.0266e-03, -1.2049e-02, -1.9737e-04, -5.1688e-03, -1.1197e-02,\n",
       "          1.0655e-03,  4.3608e-03, -2.5867e-03,  1.4019e-02, -7.0888e-03,\n",
       "         -1.0104e-02,  1.0570e-02],\n",
       "        [ 3.9165e-03, -2.4258e-04,  1.3472e-02, -9.4294e-03, -1.5610e-02,\n",
       "         -6.3373e-03,  2.3857e-03, -3.0188e-03,  1.5095e-02, -1.1014e-02,\n",
       "          6.0114e-03, -1.2862e-03,  1.3179e-02, -4.6197e-03, -9.8277e-04,\n",
       "          1.9380e-03,  2.1219e-03,  8.0261e-03, -6.2649e-03,  2.8185e-03,\n",
       "          1.5399e-02, -1.4409e-02,  1.1230e-02,  7.8588e-03, -1.1106e-02,\n",
       "          4.4224e-03, -1.4367e-02, -2.8072e-03,  2.5892e-03,  6.5484e-03,\n",
       "         -9.4190e-03,  1.7591e-04],\n",
       "        [ 6.8456e-03, -6.4135e-03, -1.6970e-03,  1.4612e-02, -2.4418e-03,\n",
       "          6.7923e-03,  1.5232e-02,  1.1432e-02, -1.3261e-02,  1.0656e-02,\n",
       "         -3.0676e-03, -1.5139e-02, -7.5009e-03,  8.8843e-03, -4.8247e-03,\n",
       "         -1.2468e-02, -3.7537e-03,  6.0721e-03,  4.3988e-03,  7.6107e-03,\n",
       "          1.5458e-02, -4.3371e-03,  5.4170e-03, -3.8030e-04,  5.4950e-03,\n",
       "         -1.4678e-02,  1.3501e-02, -1.1894e-02, -3.0247e-03,  7.4699e-03,\n",
       "         -1.4326e-02,  2.6145e-03],\n",
       "        [ 1.3690e-02,  5.0739e-03,  2.8156e-03, -1.0170e-02,  5.8901e-03,\n",
       "         -6.1460e-03,  5.9449e-03,  5.1825e-03,  9.6134e-03, -2.8758e-03,\n",
       "          1.1496e-02,  8.2836e-03, -2.0361e-03,  4.1155e-03, -1.2683e-02,\n",
       "         -5.4905e-03,  4.7308e-03, -1.2789e-02,  9.6537e-03, -4.0864e-03,\n",
       "         -9.6921e-03,  1.1172e-02, -4.5422e-03,  1.5309e-02, -6.4882e-03,\n",
       "         -1.1947e-02,  3.9968e-03, -1.3160e-02,  1.1113e-03,  1.4567e-02,\n",
       "         -1.2326e-02,  1.2424e-02],\n",
       "        [-8.0521e-03, -1.0185e-03,  7.0341e-03,  4.4096e-03,  3.8887e-03,\n",
       "          2.5127e-03,  1.2824e-02,  4.4275e-03, -4.1744e-03, -1.2832e-02,\n",
       "         -2.8691e-03,  1.2550e-02,  1.7666e-03,  3.7734e-03, -1.3608e-02,\n",
       "          6.0036e-03, -9.4003e-03, -9.9884e-03,  8.2100e-03, -9.4134e-03,\n",
       "          8.1466e-03, -6.0616e-03, -1.5018e-02, -2.6429e-03, -8.9792e-03,\n",
       "          3.0221e-03, -6.7445e-03,  6.9703e-03,  1.0152e-03,  1.3305e-02,\n",
       "         -3.1632e-03, -9.6872e-05],\n",
       "        [ 8.8537e-03,  9.4500e-03, -8.8276e-03, -2.0774e-03,  2.5956e-04,\n",
       "         -1.7954e-03,  4.6202e-03, -7.9148e-03, -1.1146e-02,  4.8531e-03,\n",
       "          9.6314e-03,  9.8110e-03, -2.3519e-03, -1.3821e-02, -1.6693e-02,\n",
       "          7.2338e-03,  2.6532e-03,  9.4011e-03,  1.0452e-02,  1.0451e-02,\n",
       "          1.5003e-02,  1.0493e-02, -1.5169e-02, -1.2479e-03,  1.1656e-02,\n",
       "         -5.8563e-03,  1.3645e-03,  2.1231e-03, -1.2546e-03,  1.1830e-02,\n",
       "          9.5850e-04, -1.1777e-02],\n",
       "        [ 9.0698e-05,  9.8819e-04,  5.7771e-03,  4.1786e-03, -4.2915e-03,\n",
       "         -6.8246e-03, -7.1190e-03,  3.9776e-03,  3.5162e-03,  7.0637e-03,\n",
       "         -5.0037e-03, -7.6001e-03,  2.7641e-03, -1.0808e-03, -5.7283e-03,\n",
       "          1.3881e-02,  6.3667e-03, -2.1065e-03,  1.3436e-02, -5.5598e-03,\n",
       "         -1.2568e-02,  1.2302e-02, -8.6117e-03, -1.2879e-02, -1.0227e-02,\n",
       "         -1.0419e-02, -6.1313e-03,  5.0862e-04,  2.4407e-03, -1.1272e-02,\n",
       "          8.2592e-03,  1.0876e-02],\n",
       "        [ 6.8421e-03,  2.4219e-03, -1.5013e-02,  1.3238e-03,  4.0118e-03,\n",
       "          1.2933e-03,  1.2022e-02,  1.1386e-02,  1.0250e-02,  6.2729e-03,\n",
       "         -2.0617e-03,  7.9355e-03,  3.5249e-03, -2.3488e-03,  7.6987e-03,\n",
       "         -4.7044e-03,  7.0346e-03,  8.2724e-03,  4.8487e-04, -7.1788e-03,\n",
       "          9.7290e-03,  6.9427e-03,  5.0785e-03, -1.2727e-02, -1.0940e-02,\n",
       "          1.1232e-02,  2.8563e-03, -1.5448e-02, -1.4499e-02,  8.5198e-03,\n",
       "         -1.3950e-03,  8.0126e-04],\n",
       "        [ 1.2651e-02, -8.7522e-03, -8.6300e-03, -1.2453e-03, -9.4375e-03,\n",
       "          1.0634e-02,  4.3302e-03, -5.4497e-03,  1.1193e-02,  5.3637e-03,\n",
       "         -1.3915e-02,  1.5495e-02,  8.1927e-03, -6.8280e-03, -1.4587e-02,\n",
       "         -1.4462e-02,  1.6192e-03,  7.9035e-03, -1.1989e-02, -1.0610e-02,\n",
       "          1.3760e-02, -2.9220e-03,  2.6055e-03,  3.2073e-03, -5.0829e-03,\n",
       "          1.7453e-03, -1.2849e-02, -1.2931e-02,  7.6090e-03, -1.3705e-02,\n",
       "         -1.9691e-03, -6.9576e-03],\n",
       "        [ 1.1567e-02, -8.5426e-03, -3.2666e-03, -1.9524e-03, -1.4312e-02,\n",
       "          8.9865e-04,  5.4321e-03, -4.8936e-04, -1.3122e-02,  2.7656e-03,\n",
       "         -3.4372e-04, -1.3483e-02,  5.7055e-04, -5.4053e-04,  8.2457e-03,\n",
       "          2.7655e-03, -4.9811e-03,  1.1766e-02, -1.1591e-02, -1.0142e-02,\n",
       "          4.8313e-03, -1.3327e-02,  2.1713e-03, -6.8277e-03, -7.8431e-03,\n",
       "         -1.1514e-02,  4.0690e-03, -1.4041e-02, -1.3033e-02, -2.1548e-03,\n",
       "          7.2052e-03,  3.9943e-03]])"
      ]
     },
     "execution_count": 27,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "embedding_maxtrix"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[ 0.0003, -0.0059,  0.0051,  0.0051,  0.0135,  0.0082,  0.0074, -0.0120,\n",
       "         -0.0020,  0.0143,  0.0112,  0.0108,  0.0026, -0.0134, -0.0140,  0.0047,\n",
       "          0.0036,  0.0011,  0.0031,  0.0110, -0.0055,  0.0126, -0.0059, -0.0103,\n",
       "          0.0153, -0.0147,  0.0005,  0.0104, -0.0134,  0.0148, -0.0100,  0.0077]])"
      ]
     },
     "execution_count": 28,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Look up row 0 of the pretrained matrix through an embedding layer.\n",
    "nn.Embedding.from_pretrained(embedding_maxtrix)(torch.tensor([0], dtype=torch.long))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "metadata": {},
   "outputs": [],
   "source": [
    "class LinkPrediction(nn.Module):\n",
    "    \"\"\"Two-class classifier over pairs of node indices.\n",
    "\n",
    "    Both indices of a pair are looked up in the pretrained embedding table,\n",
    "    the two vectors are flattened into one feature row, and a small MLP\n",
    "    maps that row to two logits.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self):\n",
    "        super().__init__()\n",
    "        # freeze=True keeps the pretrained vectors fixed: the embedding weight\n",
    "        # has requires_grad=False, so the optimizer never updates it.\n",
    "        self.emb = nn.Embedding.from_pretrained(embedding_maxtrix, freeze=True)\n",
    "\n",
    "        # The first layer takes dimensions * 2 features, i.e. it expects the\n",
    "        # embedding width to equal `dimensions` and two indices per input row.\n",
    "        self.model = nn.Sequential(\n",
    "            nn.Linear(dimensions * 2, dimensions),\n",
    "            nn.Dropout(0.5),\n",
    "            nn.ReLU(),\n",
    "            nn.Linear(dimensions, 2),\n",
    "        )\n",
    "\n",
    "    def forward(self, x):\n",
    "        \"\"\"Return one row of two logits per row of integer indices in `x`.\"\"\"\n",
    "        x = self.emb(x)\n",
    "        # Flatten everything after the batch axis into a single feature row.\n",
    "        x = x.reshape(x.shape[0], -1)\n",
    "        return self.model(x)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 30,
   "metadata": {},
   "outputs": [],
   "source": [
    "real_edges = [[edge[0], edge[1], 0] for edge in edges]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 31,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[['1', '2', 0],\n",
       " ['1', '4', 0],\n",
       " ['1', '5', 0],\n",
       " ['2', '3', 0],\n",
       " ['3', '4', 0],\n",
       " ['4', '5', 0],\n",
       " ['3', '6', 0],\n",
       " ['6', '7', 0],\n",
       " ['7', '8', 0],\n",
       " ['8', '9', 0],\n",
       " ['7', 'A', 0],\n",
       " ['A', 'B', 0]]"
      ]
     },
     "execution_count": 31,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "real_edges"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 32,
   "metadata": {},
   "outputs": [],
   "source": [
    "fake_edges = [('1', '7', 1), ('1', '8', 1), ('1', '9', 1), ('2', '8', 1), ('2', '9', 1), \n",
    "             ('9', '5', 1), ('9', '4', 1), ('7', '4', 1), ('7', '5', 1)]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 33,
   "metadata": {},
   "outputs": [],
   "source": [
    "dataset = real_edges + fake_edges"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 34,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Rebuild from real_edges/fake_edges instead of extending `dataset` in place,\n",
    "# so re-running this cell cannot append the mirrored pairs a second time.\n",
    "labelled_pairs = real_edges + fake_edges\n",
    "dataset = labelled_pairs + [[dst, src, label] for src, dst, label in labelled_pairs]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 35,
   "metadata": {},
   "outputs": [],
   "source": [
    "lp = LinkPrediction()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 36,
   "metadata": {},
   "outputs": [],
   "source": [
    "optimizer = optim.Adam(lp.parameters(), lr=2e-3)\n",
    "loss_function = nn.CrossEntropyLoss()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 37,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[-0.1190,  0.0920]], grad_fn=<AddmmBackward>)"
      ]
     },
     "execution_count": 37,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Smoke test on a single pair of indices; the result is one row of two logits.\n",
    "lp(torch.tensor([[1, 2]], dtype=torch.long))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 38,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'7': 0,\n",
       " '3': 1,\n",
       " '4': 2,\n",
       " '1': 3,\n",
       " '6': 4,\n",
       " '8': 5,\n",
       " 'A': 6,\n",
       " '2': 7,\n",
       " '5': 8,\n",
       " '9': 9,\n",
       " 'B': 10}"
      ]
     },
     "execution_count": 38,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "word2idx"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 39,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{0: '7',\n",
       " 1: '3',\n",
       " 2: '4',\n",
       " 3: '1',\n",
       " 4: '6',\n",
       " 5: '8',\n",
       " 6: 'A',\n",
       " 7: '2',\n",
       " 8: '5',\n",
       " 9: '9',\n",
       " 10: 'B'}"
      ]
     },
     "execution_count": 39,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "idx2word"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 40,
   "metadata": {},
   "outputs": [],
   "source": [
    "import random\n",
    "random.shuffle(dataset)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 41,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[['8', '1', 1],\n",
       " ['5', '9', 1],\n",
       " ['4', '3', 0],\n",
       " ['1', '4', 0],\n",
       " ['4', '1', 0],\n",
       " ['3', '6', 0],\n",
       " ['1', '5', 0],\n",
       " ('7', '4', 1),\n",
       " ['9', '1', 1],\n",
       " ('1', '7', 1),\n",
       " ('9', '4', 1),\n",
       " ['9', '2', 1],\n",
       " ['7', '8', 0],\n",
       " ('7', '5', 1),\n",
       " ['A', 'B', 0],\n",
       " ['5', '7', 1],\n",
       " ['6', '7', 0],\n",
       " ['7', 'A', 0],\n",
       " ['9', '8', 0],\n",
       " ['5', '4', 0],\n",
       " ['8', '7', 0],\n",
       " ['6', '3', 0],\n",
       " ['8', '9', 0],\n",
       " ['5', '1', 0],\n",
       " ['7', '6', 0],\n",
       " ['A', '7', 0],\n",
       " ('2', '8', 1),\n",
       " ['8', '2', 1],\n",
       " ['B', 'A', 0],\n",
       " ['4', '9', 1],\n",
       " ('1', '9', 1),\n",
       " ('9', '5', 1),\n",
       " ['3', '4', 0],\n",
       " ['2', '3', 0],\n",
       " ['7', '1', 1],\n",
       " ['4', '5', 0],\n",
       " ('1', '8', 1),\n",
       " ['1', '2', 0],\n",
       " ['3', '2', 0],\n",
       " ['2', '1', 0],\n",
       " ['4', '7', 1],\n",
       " ('2', '9', 1)]"
      ]
     },
     "execution_count": 41,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 42,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Translate each (src, dst, label) record into [src_index, dst_index, label].\n",
    "trainset = [\n",
    "    [word2idx[src], word2idx[dst], label]\n",
    "    for src, dst, label in dataset\n",
    "]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 43,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[[5, 3, 1],\n",
       " [8, 9, 1],\n",
       " [2, 1, 0],\n",
       " [3, 2, 0],\n",
       " [2, 3, 0],\n",
       " [1, 4, 0],\n",
       " [3, 8, 0],\n",
       " [0, 2, 1],\n",
       " [9, 3, 1],\n",
       " [3, 0, 1],\n",
       " [9, 2, 1],\n",
       " [9, 7, 1],\n",
       " [0, 5, 0],\n",
       " [0, 8, 1],\n",
       " [6, 10, 0],\n",
       " [8, 0, 1],\n",
       " [4, 0, 0],\n",
       " [0, 6, 0],\n",
       " [9, 5, 0],\n",
       " [8, 2, 0],\n",
       " [5, 0, 0],\n",
       " [4, 1, 0],\n",
       " [5, 9, 0],\n",
       " [8, 3, 0],\n",
       " [0, 4, 0],\n",
       " [6, 0, 0],\n",
       " [7, 5, 1],\n",
       " [5, 7, 1],\n",
       " [10, 6, 0],\n",
       " [2, 9, 1],\n",
       " [3, 9, 1],\n",
       " [9, 8, 1],\n",
       " [1, 2, 0],\n",
       " [7, 1, 0],\n",
       " [0, 3, 1],\n",
       " [2, 8, 0],\n",
       " [3, 5, 1],\n",
       " [3, 7, 0],\n",
       " [1, 7, 0],\n",
       " [7, 3, 0],\n",
       " [2, 0, 1],\n",
       " [7, 9, 1]]"
      ]
     },
     "execution_count": 43,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "trainset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 44,
   "metadata": {},
   "outputs": [],
   "source": [
    "trainset = torch.from_numpy(np.array(trainset))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 45,
   "metadata": {},
   "outputs": [],
   "source": [
    "from torch.utils.data import DataLoader"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 46,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[ 5,  3,  1],\n",
       "        [ 8,  9,  1],\n",
       "        [ 2,  1,  0],\n",
       "        [ 3,  2,  0],\n",
       "        [ 2,  3,  0],\n",
       "        [ 1,  4,  0],\n",
       "        [ 3,  8,  0],\n",
       "        [ 0,  2,  1],\n",
       "        [ 9,  3,  1],\n",
       "        [ 3,  0,  1],\n",
       "        [ 9,  2,  1],\n",
       "        [ 9,  7,  1],\n",
       "        [ 0,  5,  0],\n",
       "        [ 0,  8,  1],\n",
       "        [ 6, 10,  0],\n",
       "        [ 8,  0,  1],\n",
       "        [ 4,  0,  0],\n",
       "        [ 0,  6,  0],\n",
       "        [ 9,  5,  0],\n",
       "        [ 8,  2,  0],\n",
       "        [ 5,  0,  0],\n",
       "        [ 4,  1,  0],\n",
       "        [ 5,  9,  0],\n",
       "        [ 8,  3,  0],\n",
       "        [ 0,  4,  0],\n",
       "        [ 6,  0,  0],\n",
       "        [ 7,  5,  1],\n",
       "        [ 5,  7,  1],\n",
       "        [10,  6,  0],\n",
       "        [ 2,  9,  1],\n",
       "        [ 3,  9,  1],\n",
       "        [ 9,  8,  1],\n",
       "        [ 1,  2,  0],\n",
       "        [ 7,  1,  0],\n",
       "        [ 0,  3,  1],\n",
       "        [ 2,  8,  0],\n",
       "        [ 3,  5,  1],\n",
       "        [ 3,  7,  0],\n",
       "        [ 1,  7,  0],\n",
       "        [ 7,  3,  0],\n",
       "        [ 2,  0,  1],\n",
       "        [ 7,  9,  1]], dtype=torch.int32)"
      ]
     },
     "execution_count": 46,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "trainset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 47,
   "metadata": {},
   "outputs": [],
   "source": [
    "trainset = DataLoader(trainset, batch_size=17, shuffle=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 48,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[3, 7, 0],\n",
       "        [5, 0, 0],\n",
       "        [8, 3, 0],\n",
       "        [6, 0, 0],\n",
       "        [2, 9, 1],\n",
       "        [9, 3, 1],\n",
       "        [7, 1, 0],\n",
       "        [7, 3, 0],\n",
       "        [0, 4, 0],\n",
       "        [9, 7, 1],\n",
       "        [0, 8, 1],\n",
       "        [5, 9, 0],\n",
       "        [3, 0, 1],\n",
       "        [3, 5, 1],\n",
       "        [1, 4, 0],\n",
       "        [2, 0, 1],\n",
       "        [9, 2, 1]], dtype=torch.int32)"
      ]
     },
     "execution_count": 48,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "next(iter(trainset))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 49,
   "metadata": {},
   "outputs": [],
   "source": [
    "def acc(out, tar):\n",
    "    \"\"\"Return the fraction of rows in `out` whose argmax over the last dim equals `tar`.\"\"\"\n",
    "    predicted = out.argmax(dim=-1)\n",
    "    n_correct = (predicted == tar).sum().item()\n",
    "    return n_correct / len(out)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 50,
   "metadata": {},
   "outputs": [],
   "source": [
    "# acc(out, target)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 51,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch[100/5000] step[1/3] loss = 0.650657057762146 acc = 0.7647\n",
      "Epoch[200/5000] step[1/3] loss = 0.5806478261947632 acc = 0.7647\n",
      "Epoch[300/5000] step[1/3] loss = 0.48537999391555786 acc = 0.7647\n",
      "Epoch[400/5000] step[1/3] loss = 0.4118763506412506 acc = 0.8235\n",
      "Epoch[500/5000] step[1/3] loss = 0.42787036299705505 acc = 0.7647\n",
      "Epoch[600/5000] step[1/3] loss = 0.4109613597393036 acc = 0.7059\n",
      "Epoch[700/5000] step[1/3] loss = 0.2076452374458313 acc = 0.9412\n",
      "Epoch[800/5000] step[1/3] loss = 0.18956951797008514 acc = 0.8824\n",
      "Epoch[900/5000] step[1/3] loss = 0.2377050668001175 acc = 0.9412\n",
      "Epoch[1000/5000] step[1/3] loss = 0.0944388136267662 acc = 0.9412\n",
      "Epoch[1100/5000] step[1/3] loss = 0.15609750151634216 acc = 1.0\n",
      "Epoch[1200/5000] step[1/3] loss = 0.12618091702461243 acc = 1.0\n",
      "Epoch[1300/5000] step[1/3] loss = 0.05262225493788719 acc = 1.0\n",
      "Epoch[1400/5000] step[1/3] loss = 0.10535813868045807 acc = 1.0\n",
      "Epoch[1500/5000] step[1/3] loss = 0.17256049811840057 acc = 0.8824\n",
      "Epoch[1600/5000] step[1/3] loss = 0.25195586681365967 acc = 0.8824\n",
      "Epoch[1700/5000] step[1/3] loss = 0.12540486454963684 acc = 0.9412\n",
      "Epoch[1800/5000] step[1/3] loss = 0.0594380758702755 acc = 1.0\n",
      "Epoch[1900/5000] step[1/3] loss = 0.051771730184555054 acc = 1.0\n",
      "Epoch[2000/5000] step[1/3] loss = 0.15031315386295319 acc = 0.9412\n",
      "Epoch[2100/5000] step[1/3] loss = 0.054263800382614136 acc = 1.0\n",
      "Epoch[2200/5000] step[1/3] loss = 0.0424722358584404 acc = 1.0\n",
      "Epoch[2300/5000] step[1/3] loss = 0.14747731387615204 acc = 0.9412\n",
      "Epoch[2400/5000] step[1/3] loss = 0.05668887495994568 acc = 1.0\n",
      "Epoch[2500/5000] step[1/3] loss = 0.030635563656687737 acc = 1.0\n",
      "Epoch[2600/5000] step[1/3] loss = 0.05699478089809418 acc = 1.0\n",
      "Epoch[2700/5000] step[1/3] loss = 0.0990242287516594 acc = 0.9412\n",
      "Epoch[2800/5000] step[1/3] loss = 0.06138835847377777 acc = 1.0\n",
      "Epoch[2900/5000] step[1/3] loss = 0.012975622899830341 acc = 1.0\n",
      "Epoch[3000/5000] step[1/3] loss = 0.02319648303091526 acc = 1.0\n",
      "Epoch[3100/5000] step[1/3] loss = 0.023314211517572403 acc = 1.0\n",
      "Epoch[3200/5000] step[1/3] loss = 0.006477725692093372 acc = 1.0\n",
      "Epoch[3300/5000] step[1/3] loss = 0.05033406987786293 acc = 1.0\n",
      "Epoch[3400/5000] step[1/3] loss = 0.024500420317053795 acc = 1.0\n",
      "Epoch[3500/5000] step[1/3] loss = 0.021417615935206413 acc = 1.0\n",
      "Epoch[3600/5000] step[1/3] loss = 0.01592055894434452 acc = 1.0\n",
      "Epoch[3700/5000] step[1/3] loss = 0.005139999091625214 acc = 1.0\n",
      "Epoch[3800/5000] step[1/3] loss = 0.006936040241271257 acc = 1.0\n",
      "Epoch[3900/5000] step[1/3] loss = 0.025589505210518837 acc = 1.0\n",
      "Epoch[4000/5000] step[1/3] loss = 0.033861178904771805 acc = 1.0\n",
      "Epoch[4100/5000] step[1/3] loss = 0.10482615232467651 acc = 0.9412\n",
      "Epoch[4200/5000] step[1/3] loss = 0.022425994277000427 acc = 1.0\n",
      "Epoch[4300/5000] step[1/3] loss = 0.09694712609052658 acc = 0.9412\n",
      "Epoch[4400/5000] step[1/3] loss = 0.03960435464978218 acc = 1.0\n",
      "Epoch[4500/5000] step[1/3] loss = 0.014901852235198021 acc = 1.0\n",
      "Epoch[4600/5000] step[1/3] loss = 0.12253379821777344 acc = 0.9412\n",
      "Epoch[4700/5000] step[1/3] loss = 0.12817253172397614 acc = 0.9412\n",
      "Epoch[4800/5000] step[1/3] loss = 0.016325047239661217 acc = 1.0\n",
      "Epoch[4900/5000] step[1/3] loss = 0.00017286979709751904 acc = 1.0\n",
      "Epoch[5000/5000] step[1/3] loss = 0.023859413340687752 acc = 1.0\n"
     ]
    }
   ],
   "source": [
    "epochs = 5000\n",
    "log_every = 100  # report the first batch of every `log_every`-th epoch\n",
    "\n",
    "# Switch to training mode explicitly so Dropout stays active even when this\n",
    "# cell is re-run after lp.eval() has been called.\n",
    "lp.train()\n",
    "for i in range(epochs):\n",
    "    for idx, item in enumerate(trainset):\n",
    "        ipt = item[:, :2]            # columns 0-1: the two node indices\n",
    "        target = item[:, 2].long()   # column 2: the class label\n",
    "        out = lp(ipt.long())\n",
    "        loss = loss_function(out, target)\n",
    "        # Accuracy is only reported, so compute it only for the logged batches.\n",
    "        if (i + 1) % log_every == 0 and idx == 0:\n",
    "            ac = acc(out, target)\n",
    "            print(f\"Epoch[{i + 1}/{epochs}] step[{idx + 1}/{len(trainset)}] loss = {loss.item()} acc = {round(ac, 4)}\")\n",
    "        optimizer.zero_grad()\n",
    "        loss.backward()\n",
    "        optimizer.step()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 52,
   "metadata": {},
   "outputs": [],
   "source": [
    "from torchsummary import summary"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 53,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "=================================================================\n",
      "Layer (type:depth-idx)                   Param #\n",
      "=================================================================\n",
      "├─Embedding: 1-1                         (352)\n",
      "├─Sequential: 1-2                        --\n",
      "|    └─Linear: 2-1                       2,080\n",
      "|    └─Dropout: 2-2                      --\n",
      "|    └─ReLU: 2-3                         --\n",
      "|    └─Linear: 2-4                       66\n",
      "=================================================================\n",
      "Total params: 2,498\n",
      "Trainable params: 2,146\n",
      "Non-trainable params: 352\n",
      "=================================================================\n"
     ]
    }
   ],
   "source": [
    "s = summary(lp)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 54,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'7': 0,\n",
       " '3': 1,\n",
       " '4': 2,\n",
       " '1': 3,\n",
       " '6': 4,\n",
       " '8': 5,\n",
       " 'A': 6,\n",
       " '2': 7,\n",
       " '5': 8,\n",
       " '9': 9,\n",
       " 'B': 10}"
      ]
     },
     "execution_count": 54,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "word2idx"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 55,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "LinkPrediction(\n",
       "  (emb): Embedding(11, 32)\n",
       "  (model): Sequential(\n",
       "    (0): Linear(in_features=64, out_features=32, bias=True)\n",
       "    (1): Dropout(p=0.5, inplace=False)\n",
       "    (2): ReLU()\n",
       "    (3): Linear(in_features=32, out_features=2, bias=True)\n",
       "  )\n",
       ")"
      ]
     },
     "execution_count": 55,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "lp.eval()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 56,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([[ 13.7926, -13.8530]])\n"
     ]
    }
   ],
   "source": [
    "# ('B', 'A') mirrors the real edge ('A', 'B'), which was labelled 0 for training.\n",
    "with torch.no_grad():\n",
    "    print(lp(torch.tensor([[word2idx['B'], word2idx['A']]], dtype=torch.long)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 57,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([[-2.5851,  2.5586]])\n"
     ]
    }
   ],
   "source": [
    "# ('3', '9') is in neither real_edges nor fake_edges, so it was not trained on.\n",
    "with torch.no_grad():\n",
    "    print(lp(torch.tensor([[word2idx['3'], word2idx['9']]], dtype=torch.long)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 58,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([[-4.4713,  4.3654]])\n"
     ]
    }
   ],
   "source": [
    "# ('9', '5') is listed in fake_edges, i.e. it was a training pair labelled 1.\n",
    "with torch.no_grad():\n",
    "    print(lp(torch.tensor([[word2idx['9'], word2idx['5']]], dtype=torch.long)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 59,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([[ 8.9413, -9.0097]])\n"
     ]
    }
   ],
   "source": [
    "# ('B', '4') is in neither real_edges nor fake_edges, so it was not trained on.\n",
    "with torch.no_grad():\n",
    "    print(lp(torch.tensor([[word2idx['B'], word2idx['4']]], dtype=torch.long)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 60,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([[ 8.9067, -8.9923]])\n"
     ]
    }
   ],
   "source": [
    "# ('2', '4') is in neither real_edges nor fake_edges, so it was not trained on.\n",
    "with torch.no_grad():\n",
    "    print(lp(torch.tensor([[word2idx['2'], word2idx['4']]], dtype=torch.long)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.7"
  },
  "toc": {
   "base_numbering": 1,
   "nav_menu": {},
   "number_sections": true,
   "sideBar": true,
   "skip_h1_title": false,
   "title_cell": "Table of Contents",
   "title_sidebar": "Contents",
   "toc_cell": false,
   "toc_position": {},
   "toc_section_display": true,
   "toc_window_display": false
  },
  "varInspector": {
   "cols": {
    "lenName": 16,
    "lenType": 16,
    "lenVar": 40
   },
   "kernels_config": {
    "python": {
     "delete_cmd_postfix": "",
     "delete_cmd_prefix": "del ",
     "library": "var_list.py",
     "varRefreshCmd": "print(var_dic_list())"
    },
    "r": {
     "delete_cmd_postfix": ") ",
     "delete_cmd_prefix": "rm(",
     "library": "var_list.r",
     "varRefreshCmd": "cat(var_dic_list()) "
    }
   },
   "types_to_exclude": [
    "module",
    "function",
    "builtin_function_or_method",
    "instance",
    "_Feature"
   ],
   "window_display": false
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
